{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The below code is inspired from the work done by a kaggle user \"Keita Kurita\". \n",
    "\n",
    "* The idea is to generate a fastai based databunch which will be used to finally fine tune a bert model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from pathlib import Path\n",
    "from typing import *\n",
    "\n",
    "import torch\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.vision import *\n",
    "from fastai.text import *\n",
    "from fastai.callbacks import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: pytorch-pretrained-bert in /opt/conda/lib/python3.6/site-packages (0.6.2)\n",
      "Requirement already satisfied: torch>=0.4.1 in /opt/conda/lib/python3.6/site-packages (from pytorch-pretrained-bert) (1.2.0)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.6/site-packages (from pytorch-pretrained-bert) (2.22.0)\n",
      "Requirement already satisfied: tqdm in /opt/conda/lib/python3.6/site-packages (from pytorch-pretrained-bert) (4.32.1)\n",
      "Requirement already satisfied: regex in /opt/conda/lib/python3.6/site-packages (from pytorch-pretrained-bert) (2019.8.19)\n",
      "Requirement already satisfied: numpy in /opt/conda/lib/python3.6/site-packages (from pytorch-pretrained-bert) (1.17.0)\n",
      "Requirement already satisfied: boto3 in /opt/conda/lib/python3.6/site-packages (from pytorch-pretrained-bert) (1.9.212)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.6/site-packages (from requests->pytorch-pretrained-bert) (2019.6.16)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/lib/python3.6/site-packages (from requests->pytorch-pretrained-bert) (1.24.2)\n",
      "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/lib/python3.6/site-packages (from requests->pytorch-pretrained-bert) (3.0.4)\n",
      "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/lib/python3.6/site-packages (from requests->pytorch-pretrained-bert) (2.8)\n",
      "Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /opt/conda/lib/python3.6/site-packages (from boto3->pytorch-pretrained-bert) (0.2.1)\n",
      "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /opt/conda/lib/python3.6/site-packages (from boto3->pytorch-pretrained-bert) (0.9.4)\n",
      "Requirement already satisfied: botocore<1.13.0,>=1.12.212 in /opt/conda/lib/python3.6/site-packages (from boto3->pytorch-pretrained-bert) (1.12.212)\n",
      "Requirement already satisfied: python-dateutil<3.0.0,>=2.1; python_version >= \"2.7\" in /opt/conda/lib/python3.6/site-packages (from botocore<1.13.0,>=1.12.212->boto3->pytorch-pretrained-bert) (2.8.0)\n",
      "Requirement already satisfied: docutils<0.16,>=0.10 in /opt/conda/lib/python3.6/site-packages (from botocore<1.13.0,>=1.12.212->boto3->pytorch-pretrained-bert) (0.15.2)\n",
      "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.6/site-packages (from python-dateutil<3.0.0,>=2.1; python_version >= \"2.7\"->botocore<1.13.0,>=1.12.212->boto3->pytorch-pretrained-bert) (1.12.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install pytorch-pretrained-bert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Config(dict):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        for k, v in kwargs.items():\n",
    "            setattr(self, k, v)\n",
    "    \n",
    "    def set(self, key, val):\n",
    "        self[key] = val\n",
    "        setattr(self, key, val)\n",
    "\n",
    "config = Config(\n",
    "    bert_model_name=\"bert-base-uncased\",\n",
    "    max_lr=3e-5,\n",
    "    epochs=10,\n",
    "    use_fp16=True,\n",
    "    bs=32,\n",
    "    discriminative=False,\n",
    "    max_seq_len=256,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 231508/231508 [00:00<00:00, 5762186.93B/s]\n"
     ]
    }
   ],
   "source": [
    "from pytorch_pretrained_bert import BertTokenizer\n",
    "bert_tok = BertTokenizer.from_pretrained(\n",
    "    config.bert_model_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _join_texts(texts:Collection[str], mark_fields:bool=False, sos_token:Optional[str]=BOS):\n",
    "    \"\"\"Borrowed from fast.ai source\"\"\"\n",
    "    if not isinstance(texts, np.ndarray): texts = np.array(texts)\n",
    "    if is1d(texts): texts = texts[:,None]\n",
    "    df = pd.DataFrame({i:texts[:,i] for i in range(texts.shape[1])})\n",
    "    text_col = f'{FLD} {1} ' + df[0].astype(str) if mark_fields else df[0].astype(str)\n",
    "    if sos_token is not None: text_col = f\"{sos_token} \" + text_col\n",
    "    for i in range(1,len(df.columns)):\n",
    "        #text_col += (f' {FLD} {i+1} ' if mark_fields else ' ') + df[i]\n",
    "        text_col += (f' {FLD} {i+1} ' if mark_fields else ' ') + df[i].astype(str)\n",
    "    return text_col.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FastAiBertTokenizer(BaseTokenizer):\n",
    "    \"\"\"Wrapper around BertTokenizer to be compatible with fast.ai\"\"\"\n",
    "    def __init__(self, tokenizer: BertTokenizer, max_seq_len: int=128, **kwargs):\n",
    "        self._pretrained_tokenizer = tokenizer\n",
    "        self.max_seq_len = max_seq_len\n",
    "\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        return self\n",
    "\n",
    "    def tokenizer(self, t:str) -> List[str]:\n",
    "        \"\"\"Limits the maximum sequence length\"\"\"\n",
    "        return [\"[CLS]\"] + self._pretrained_tokenizer.tokenize(t)[:self.max_seq_len - 2] + [\"[SEP]\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(7628, 2) (7628, 2)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "DATA_ROOT = Path(\"..\") / \"input\" / \"sentiment\"\n",
    "\n",
    "train, test = [pd.read_excel(DATA_ROOT / fname) for fname in [\"Data_Train.xlsx\", \"Data_Test.xlsx\"]]\n",
    "val = train\n",
    "print(train.shape, val.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "fastai_bert_vocab = Vocab(list(bert_tok.vocab.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "fastai_tokenizer = Tokenizer(tok_func=FastAiBertTokenizer(bert_tok, max_seq_len=config.max_seq_len), pre_rules=[], post_rules=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>STORY</th>\n",
       "      <th>SECTION</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>But the most painful was the huge reversal in ...</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>How formidable is the opposition alliance amon...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Most Asian currencies were trading lower today...</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>If you want to answer any question, click on ‘...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>In global markets, gold prices edged up today ...</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               STORY  SECTION\n",
       "0  But the most painful was the huge reversal in ...        3\n",
       "1  How formidable is the opposition alliance amon...        0\n",
       "2  Most Asian currencies were trading lower today...        3\n",
       "3  If you want to answer any question, click on ‘...        1\n",
       "4  In global markets, gold prices edged up today ...        3"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.get_dummies(train, columns=['SECTION'])\n",
    "val = pd.get_dummies(val, columns=['SECTION'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_cols = ['SECTION_0', 'SECTION_1', 'SECTION_2', 'SECTION_3']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertTokenizeProcessor(TokenizeProcessor):\n",
    "    def __init__(self, tokenizer):\n",
    "        super().__init__(tokenizer=tokenizer, include_bos=False, include_eos=False)\n",
    "\n",
    "class BertNumericalizeProcessor(NumericalizeProcessor):\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        super().__init__(*args, vocab=Vocab(list(bert_tok.vocab.keys())), **kwargs)\n",
    "\n",
    "def get_bert_processor(tokenizer:Tokenizer=None, vocab:Vocab=None):\n",
    "    return [BertTokenizeProcessor(tokenizer=tokenizer),\n",
    "            NumericalizeProcessor(vocab=vocab)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertDataBunch(TextDataBunch):\n",
    "    @classmethod\n",
    "    def from_df(cls, path:PathOrStr, train_df:DataFrame, valid_df:DataFrame, test_df:Optional[DataFrame]=None,\n",
    "                tokenizer:Tokenizer=None, vocab:Vocab=None, classes:Collection[str]=None, text_cols:IntsOrStrs=1,\n",
    "                label_cols:IntsOrStrs=0, label_delim:str=None, **kwargs) -> DataBunch:\n",
    "        \"Create a `TextDataBunch` from DataFrames.\"\n",
    "        p_kwargs, kwargs = split_kwargs_by_func(kwargs, get_bert_processor)\n",
    "        # use our custom processors while taking tokenizer and vocab as kwargs\n",
    "        processor = get_bert_processor(tokenizer=tokenizer, vocab=vocab, **p_kwargs)\n",
    "        if classes is None and is_listy(label_cols) and len(label_cols) > 1: classes = label_cols\n",
    "        src = ItemLists(path, TextList.from_df(train_df, path, cols=text_cols, processor=processor),\n",
    "                        TextList.from_df(valid_df, path, cols=text_cols, processor=processor))\n",
    "        src = src.label_for_lm() if cls==TextLMDataBunch else src.label_from_df(cols=label_cols, classes=classes)\n",
    "        if test_df is not None: src.add_test(TextList.from_df(test_df, path, cols=text_cols))\n",
    "        return src.databunch(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "databunch = BertDataBunch.from_df(\".\", train, val, test,\n",
    "                  tokenizer=fastai_tokenizer,\n",
    "                  vocab=fastai_bert_vocab,\n",
    "                  text_cols=\"STORY\",\n",
    "                  label_cols=label_cols,\n",
    "                  bs=config.bs,\n",
    "                  collate_fn=partial(pad_collate, pad_first=False, pad_idx=0),\n",
    "             )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification\n",
    "bert_model = BertForSequenceClassification.from_pretrained(config.bert_model_name, num_labels=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func = nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.callbacks import *\n",
    "\n",
    "learner = Learner(\n",
    "    databunch, bert_model,\n",
    "    loss_func=loss_func,\n",
    ")\n",
    "if config.use_fp16: learner = learner.to_fp16()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
     ]
    }
   ],
   "source": [
    "learner.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEKCAYAAAD9xUlFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8VfWd//HXJztJyAJJgARCQgIooIAgiAtitW5TRad2RruMXaa2v3GZtmP7a3+d6WKnm13sbnWs7bS1dVqt1m3UuuBSRQmr7GuAJGRPyL5/f3/cm3gbEwiQk3Pvzfv5eNyH955z7j2fryF533O+53y/5pxDREQEIMbvAkREJHwoFEREZIBCQUREBigURERkgEJBREQGKBRERGSAQkFERAYoFEREZIBCQUREBsT5XcCJysrKcgUFBX6XISISUdavX1/rnMs+3nYRFwoFBQWUlJT4XYaISEQxs4Mj2U6nj0REZIBCQUREBigURERkgEJBREQGKBRERGSAQkFERAYoFEREZIBCQUQkAvzwuT38dW+t5/tRKIiIhLmj7d384PndlJQ2eL4vhYKISJjbcKgB5+DsgkzP96VQEBEJc+tLG4iNMRblZ3i+L4WCiEiYW1daz4LcNJITvB+uTqEgIhLGunr62HS4kaUFk8ZkfwoFEZEwtrXiKJ09fSyd6X1/AigURETC2vrgFUdLxqCTGRQKIiJhbV1pPQWTk8mZmDQm+1MoiIiEKeccJQcbxqw/ARQKIiJha39tK/WtXWNyf0I/hYKISJgqKa0HYMlMHSmIiIx7JaUNZCbHU5SdMmb7VCiIiISp/v4EMxuzfXoWCmZ2v5lVm9nWYdafZmavm1mnmd3uVR0iIpGoprmTA7WtY9qfAN4eKfwKuPwY6+uB24DveliDiEhEWn8w0J8wllcegYeh4Jx7mcAf/uHWVzvn1gHdXtUgIhKp1pU2kBgXw4Lc9DHdb0T0KZjZTWZWYmYlNTU1fpcjIuK5koMNLJyRQULc2P6ZjohQcM7d65xb6pxbmp2d7Xc5IiKeauvqYVv50THvT4AICQURkfFk0+FGevrcmPcngEJBRCTsrC9twAzOyh/7IwXPZmwws98Dq4AsMysDvgzEAzjnfm5mU4ESIA3oM7NPAfOcc01e1SQiEgl2VDaRPymZ9AnxY75vz0LBOXfDcdZXAtO92r+ISKTaVdnM3CkTfdm3Th+JiISRju5eSuvaOG2qQkFEZNzbW91Cb59j7tQ0X/avUBARCSO7KpsBmDs11Zf9KxRERMLI7qpmEuJiKJg8diOjhlIoiIiEkZ2VzRRnpxIX68+fZ4WCiEgY2VXZ7FsnMygURETCxtG2biqbOpijUBARkV1V/Z3MCgURkXFvV2VgQAedPhIREXZWNpOWFMfUtCTfalAoiIiEiV2VzcydOnFM52QeTKEgIhIGnHPsqmr2tT8BFAoiImHhyNEOmjt6fBveop9CQUQkDPQPb+FnJzMoFEREwsLOYCjMyVEoiIiMe7sqm5iWnkR68thPrBNKoSAiEgZ2VbX43skMCgUREd919/axr1qhICIiQGltK129fb5NwRlKoSAi4rOdlf6PedRPoSAi4rPdVc3ExhjFOf7MthZKoSAi4rOdlc0UZqWQGBfrdykKBRERv+2qbA6L/gRQKIiI+Kq9q5dD9W3MifZQMLP7zazazLYOs97M7EdmttfMtpjZWV7VIiISrsoa2gAoyEr2uZIAL48UfgVcfoz1VwCzg4+bgLs9rEVEJCyVN7YDkJcxwedKAjwLBefcy0D9MTZZDfzaBawFMsxsmlf1iIiEo4FQyIzyUBiBPOBwyOuy4DIRkXGjvKGduBgjZ6J/s62F8jMUhppayA25odlNZlZiZiU1NTUelyUiMnbKG9uZmp5EbIx/s62F8jMUyoAZIa+nAxVDbeicu9c5t9Q5tzQ7O3tMihMRGQvlDe1h058A/obCY8A/Ba9COgc46pw74mM9IiJjrqKxPWz6EwDivPpgM/s9sArIMrMy4MtAPIBz7ufAU8CVwF6gDfiIV7WIiISj7t4+Kps6mB5GRwqehYJz7objrHfAzV7tX0Qk3FUe7aDPhc+VR6A7mkVEfNN/OWpuGB0pKBRERHxS3hBeN66BQkFExDcVOlIQEZF+5Y3tZKUmkhTv/5DZ/RQKIiI+KQ+zy1FBoSAi4pvAjWvhMbxFP4WCiIgPnHOBI4Uw6k8AhYKIiC/qWrvo7OlTKIiISMjlqJnhMblOP4WCiIgPwm1ynX4KBRERH7x9pKBQEBEZ98ob20lNjCMtybMh6E6KQkFExAdlwXkUzMJjcp1+CgURER+E2zwK/RQKIiI+CMd7FEChICIy5lo6ezja3q0jBRERefvKo3AaHbWfQkFEZIyVN7YB4XePAigURETGXHljBwDTdfpIRETKG9pJiI0hOzXR71LeQaEgIjLGyhvbmZaRRExMeN2jAAoFEZExV97QFpb9CaBQEBEZc+WN7WF55REoFERExlRXTx/VzZ06UhAREag82oFz4Tc6aj9PQ8HMLjezXWa218w+P8T6mWb2vJltMbM1Zjbdy3pERPxWFrxHYfp4O1Iws1jgp8AVwDzgBjObN2iz7wK/ds6dCdwBfNOrekREwkG4zqPQz8sjhWXAXufcfudcF/AgsHrQNvOA54PPXxxivYhIVClvbMcMpqYn+V3KkLwMhTzgcMjrsuCyUJuB9wafXwtMNLPJgz/IzG4ysxIzK6mpqfGkWBGRsbCnuoVpaUkkxsX6XcqQvAyFoe7KcINe3w5caGYbgQuBcqDnHW9y7l7n3FLn3NLs7OzRr1REZAz09Pbx6p5azi3O8ruUYXk5D1wZMCPk9XSgInQD51wF8PcAZpYKvNc5d9TDmkREfLO5rJGj7d2smhu+X269PFJYB8w2s0IzSwCuBx4L3cDMssysv4YvAPd7WI+IiK/W7KohNsa4oHgchoJzrge4BXgG2AH8wTm3zczuMLOrg5utAnaZ2W5gCvB1r+oREfHbi7uqOSs/g/TkeL9LGZaXp49wzj0FPDVo2ZdCnj8EPORlDSIi4aC6uYOt5U189rK5fpdyTLqjWURkDLy8uxaAC+eE76kjGGEomFmRmSUGn68ys9vMLMPb0kREoseaXdVkT0xkfm6a36Uc00iPFB4Ges2sGPgFUAj8zrOqRESiSE9vH6/sqeXCOdmYhd8cCqFGGgp9wY7ja4EfOOc+DUzzriwRkeix6XD4X4rab6Sh0G1mNwA3Ak8El4Vv97mISBiJhEtR+400FD4CrAC+7pw7YGaFwG+9K0tEJHqs2R3+l6L2G1EoOOe2O+duc8793swygYnOuW95XJuISMTrvxR11dwcv0sZkZFefbTGzNLMbBKBQex+aWbf97Y0EZHIFymXovYb6emjdOdcE4Fxin7pnFsCXOJdWSIi0eHFCLkUtd9IQyHOzKYB/8DbHc0iInIMPb19vLK7JiIuRe030lC4g8AYRvucc+vMbBawx7uyREQi357qFpo6ejg/jIfKHmxEYx855/4I/DHk9X7enhxHRESGUFrbCkBxTqrPlYzcSDuap5vZI2ZWbWZVZvawmU33ujgRkUh2oC4QCjMnJ/tcyciN9PTRLwnMhZBLYErNx4PLRERkGAdr28hKTWBiUvjfn9BvpKGQ7Zz7pXOuJ/j4FRAZ11eJiPjkQF0rBZNT/C7jhIw0FGrN7INmFht8fBCo87IwEZFIV1rbSkFWdIbCRwlcjloJHAGuIzD0hYiIDKG1s4fq5k4KozEUnHOHnHNXO+eynXM5zrlrCNzIJiIiQyiNwE5mOLWZ1z4zalWIiESZg3VtAFHbpzCUyLg9T0TEBweC9yhEa5/CUNyoVSEiEmVKa1vJnphIauKI7hEOG8es1syaGfqPvwETPKlIRCQKlNa1Uhhhp47gOKHgnJs4VoWIiEST0ro2VkXIcNmhTuX0kYiIDKGls4ea5s6I608Aj0PBzC43s11mttfMPj/E+nwze9HMNprZFjO70st6RETGQv9AeJF2jwJ4GApmFgv8FLgCmAfcYGbzBm3278AfnHOLgeuBn3lVj4jIWOm/RyHSLkcFb48UlgF7nXP7nXNdwIPA6kHbOKB/OqJ0oMLDekRExsTAPQpZkXXjGoxwPoWTlAccDnldBiwftM1XgGfN7FYgBU3xKSJR4EBtKzkTE0lOiKzLUcHbI4Whbm4bfHnrDcCvnHPTgSuB35jZO2oys5vMrMTMSmpqajwoVURk9ETiQHj9vAyFMmBGyOvpvPP00MeAPwA4514HkoB3zFvnnLvXObfUObc0OzvyLvESkfElUu9RAG9DYR0w28wKzSyBQEfyY4O2OQRcDGBmpxMIBR0KiEjEau7opralS0cKgznneoBbgGeAHQSuMtpmZneY2dXBzf4N+LiZbQZ+D3zYOafhM0QkYvV3MhdGYCczeNvRjHPuKeCpQcu+FPJ8O3CelzWIiIyl/oHwZur0kYiI9N+4Fon3KIBCQURkVJXWtTE1LYkJCbF+l3JSFAoiIqOotK41Im9a66dQEBEZRaW1rRE55lE/hYKIyChp6uimrrUrYjuZQaEgIjJqDtZG5rzMoRQKIiKj5EBd5A6Z3U+hICIySkoH7lFQR7OIyLi3t7qF3PQkkuIj83JUUCiIiIyK6qYOntlWyfmz3zGmZ0RRKIiIjIKfv7Sfnj7HzRcV+13KKVEoiIicouqmDh544yB/vzgvoi9HBYWCiMgp6z9KuOVdkX2UAAoFEZFTEk1HCaBQEBE5JdF0lAAKBRGRk1bdHDhKuDZKjhJAoSAictLu6T9KiPArjkIpFERETkJNcye/XRs4SojU+ZiHolAQETkJj2+uoLOnj0+snOV3KaNKoSAichKefOsIp09LY/aUiX6XMqoUCiIiJ6iisZ31Bxt4z5nT/C5l1CkURERO0JNbjgAoFEREBJ7YUsEZeelRcxlqqDi/CwhnzjnePFBPZVMHs7JSKcxOITXx2P/LnHN09vTR2tlDfFwME+JjiY+NGVhX39rFwfo2DtW1UdbQRmpiHDOzUpg5KZnpmckkxCmnRcLZobo2Npcd5QtXnOZ3KZ5QKAyhobWLhzeU8bs3D7G/pvVv1k1JS2Tm5BQM6OzpCzy6e2nv7qWls4e2rl56+9zfvCc+1pgQH0tvn6O1q3fY/cYY5GVO4LSpaZw+LY150yZy+rQ0ZmQmExNjXjRVRE7QE29VAPB3UXjqCDwOBTO7HPghEAvc55z71qD1dwEXBV8mAznOuQwvawrlnKOxrZvDDW0crm+nrKGNrRVNPLOtkq6ePhbnZ/Cd685kQV46pbWt7K9tZX9NK4fr28AgbUI8iXExJMXHkhQXQ0piHKmJcSQnxpIcH0tPn6Otq5e2rl46ugNhkD8pmZmTA4/pmck0dXRzqK6Ng3VtHKxvY39NCzuONPH8jir6s2VCfCxFOSkUZ6cye8pEinNSmTctjemZEzBTWIiMpSe3HGFxfgbTMyN3drVj8SwUzCwW+CnwbqAMWGdmjznntvdv45z7dMj2twKLvarnaHs3W8uPsruqmd1VLeypamZPdQtH27v/ZrtJKQn849IZvH95PqdPSxtYHvp8NCXFx5IzMYmlBZP+Znl7Vy+7q5rZcaSJPdUt7Klu4c0D9Ty6qWJgm7SkOOblpjE/N52UxDga27poaOumsa2Lju5e5uems6xwEmcXTCJ7YqIn9YuMJwdqW9lW0cR/vGee36V4xssjhWXAXufcfgAzexBYDWwfZvsbgC97VcyaXdX864ObAMhIjmdOzkTec+Y0CrNSmDEpmRmZyUyfNIG0pHivSjghExJiWTgjg4Uz/vbAqaWzhz1VzWw/0sS2isDjt2sP0tnTR/qEeDKT48lITiA+1nhw3SF+9VopALOyUlhWOInlsyaxvHAyuRkTfGiVSGR7YnPgS9mVZ0z1uRLveBkKecDhkNdlwPKhNjSzmUAh8MIw628CbgLIz88/qWLOK87igX9ezuwpqWSnJkbsaZfUxDgW52eyOD9zYFl/H0bsoH6Hrp4+tlYc5c0D9bx5oJ4n3zrCg+sCP5IZkyawvHAy58yazDmzJkXtobDIaHpiyxHOLshkWnr0fqnyMhSG+qvrhlgGcD3wkHNuyF5Y59y9wL0AS5cuHe4zjikrNZGs4ug8hTI4DPolxMVwVn4mZ+Vn8skLi+jtc+w40sQbB+p5Y38dz+2o4qH1ZQDkZUzgnFmTWV44icX5GRRlp6pzWyTEnqpmdlU189Wr5/tdiqe8DIUyYEbI6+lAxTDbXg/c7GEtQiA8FuSlsyAvnY+dX0hfn2N3dTNr99Wxdn89L+ys4uENgZCYmBjHwhkZLJqRwXnFWSwtyBy4tFZkPHHOUdXUya9fP0iMwRVRfOoIwJw7qS/ex/9gszhgN3AxUA6sA97vnNs2aLu5wDNAoRtBMUuXLnUlJSUeVCx9fY4Dda1sPNTIpsMNbDzUyM7KZnr7HBMT4zh/dhYXzc3h/NlZTEtPithTcCLH09jWxfee3c3WiqPsrWqhubMHgEtOz+G+G8/2ubqTY2brnXNLj7edZ0cKzrkeM7uFwB/8WOB+59w2M7sDKHHOPRbc9AbgwZEEgngrJsYoyk6lKDuV65ZMBwId23/dW8uaXdW8uLOG/91aCQROxy3IS+OM4JHH+cVZpBznxj6RSPGD5/bwuzcPcXZBJtcszmPOlFSKcyayOH/Mrpj3jWdHCl7RkYJ/nHPsONLMGwfq2FrexNbyo+ypbqbPwcSkOG5Yls+N5xaQpyubJIJVN3VwwZ0vsnpRLndet9DvckaN70cKEn3MjHm5aczLffuejfauXjYebuCBNw7xi1cP8ItXD3D5/Kl85LwClszM1CkmiTj3vByYTe3mKJpN7UQoFOSUTEiI5dyiLM4tyqK8sZ1fv1bK7988xJNvHWHm5GSuWZTHNYvzKIyimakketU0d/LAGwdZvSg3Kge7GwmdPpJR19rZw1NvHeHRTeW8tq8O52DRjAyuXZzHVQtzmZSS4HeJIkP65lM7+K9X9vPcZy5kVnaq3+WMqpGePlIoiKcqj3bw2OZy/rShnJ2VzcTFGKvmZnPN4jwuOX0KSfGxfpcoAkBdSyfnf/tFLps/hR9c79mIO75Rn4KEhanpSdy0soibVhaxs7KJRzaU8+imcp7bUU1qYhwXzs3m0nlTWDUnh/Tk8BhiRMan+149QEdPL7e8a3z2JfRTKMiYOW1qGl+4Mo3PXX4aa/fX8fjmCp7bUc2TW44QF2MsK5zEBbOzWVaYyRl5GZpbQsZMQ2sXv36tlPecmUtxTnTNuXyiFAoy5mJjjPOKszivOIu+Psemskae217Fczuq+PbTOwFIjIth0YwMls+azAeW5zMlLcnnqiVaNXd0c9dzu2nt6uXWcX6UAOpTkDBT29JJSWkD60rrKSmt563yo8THxvDBc2byyQuLNAS4nLDGti6aO3roc44+B33BGRD/ureWV/fUsvFwI719jmsX53HXPy7yu1zPqKNZosLBulZ+/MJe/rShjIS4GG5cUcDHV84iK1XhIO/knGNnZTPrDzaw4VBgqJYDta1DbmsGZ+alc/7sLM4vzubsgkzionh8L4WCRJX9NS38+IW9PLqpnPiYGC5bMJUPLM9neeEk3SAnOOdYs6uG7z67i20VTQBMTklgcX4mZ83MIGdiEjEGMWaYQUpCHEsLMslIHj+XRysUJCrtq2nhN68f5E8bymjq6KEoO4UbluVz3ZLp4+oXXN722r5avvfsbtYfbCB/UjKfvLCI84onkz8pWV8YQigUJKq1d/XyxJYKfvfmITYeaiQxLob3nJnLB8/JZ9GMDP0xGAc2HW7kO8/s5K9765ialsRtF8/mfUuna4j3YSgUZNzYXtHEA28c5NGN5bR29TJvWhofOa+AaxfnRfU54vFqd1Uz331mF89ur2JySgL/clExH1ierxshj0OhIONOS2cPj24s57drD7KzspninFRuv3Qul82foiOHCNbW1UN5Qztlje08vrmCRzaWk5oQx8dXzuKj5xeSqiHbR0ShIOOWc45ntlXxnWd2sq+mlUUzMvi/l5/GiqLJfpcmx9HY1sW60gbe2F/H+kMNHKxro761a2B9YlwMHz63gE9eWESmxtA6IQoFGfd6evt4eEMZd/1lD5VNHSwrmMQnLpzFRXNzNP90mHlhZxV3Pr2LXVXNOBeYX3zRjAyKc1LJy5jA9MzAoyg7VRcUnCSFgkhQR3cvvwvO91De2E5xTio3rZzF6kW5JMbpPLTfnHOs+u4a+pzjH5bMYPmsyZw5PV19BKNMoSAySHdvH09uOcI9L+9nx5EmMpPjeddpU3j3vCmsnJNFcoLOTfth/cEG3nv3a3znujN539IZfpcTtTRKqsgg8bExXLM4j9WLcnllTy2PbCznuR1VPBy8W/r84iw+tGImq+Zkq2N6DD2ysYyk+BiuOGOa36UICgUZh8yMlXOyWTknm+7ePtaV1vOX7VU8vbWSj/xyHUtmZvJvl87h3KIsv0uNel09fTyx5QiXzpuqq4jChC7ilnEtPjaGc4uy+PJV83npsxfx9WsXUN7Qzvv/6w0+cN9aSkrribRTrJFkza5qGtu6uXZxnt+lSJCiWSQoIS6GDyyfyXvPms4Dbxzi7jV7ue7nr3NGXjofWjGTqxfmqvNzlD2ysZys1AQumK2jsnChIwWRQZLiY/nY+YW8/LmL+No1C+jo7uVzD23hnG8+zzee2sG+mha/S4wKR9u6eX5HNVctzNWd52FERwoiw0hOiOND58zkg8vzWbu/nt+sLeUXrx7g3pf3c0ZeOqsX5XLVwlxNAHSSntp6hK7ePp06CjOehoKZXQ78EIgF7nPOfWuIbf4B+ArggM3Oufd7WZPIiTIzVhRNZkXRZKqbOnhscwV/3lTBfz65g288tYMVRZNZvSiPyxdMJS1J80yP1CMbyinKTuGMvHS/S5EQnt2nYGaxwG7g3UAZsA64wTm3PWSb2cAfgHc55xrMLMc5V32sz9V9ChIu9la3BAOinIN1bSTExXDJ6TmsXpTHRXNzNMf0MRyub+OCO1/ks5fN5eaLNAXmWAiH+xSWAXudc/uDBT0IrAa2h2zzceCnzrkGgOMFgkg4Kc5J5TPvnsOnL5nNpsONPLqxnCe2HOGptyopzErhq1fPZ+WcbL/LDEt/3lQOwNULc32uRAbz8qtMHnA45HVZcFmoOcAcM/urma0Nnm4SiShmxuL8TL66egFr/9/F3POhJQD80/1v8i8PrKeisd3nCsOLc44/bSxnWeEkZkxK9rscGcTLI4WhbgkdfK4qDpgNrAKmA6+Y2QLnXOPffJDZTcBNAPn5+aNfqcgoiY+N4bL5U1k1N5v7XjnAj1/Yw4s7a7jt4tn88wWFUTcBzGv7aimtbSM5ITb4iCMlMZbJKYlMSk0gJSEWM6O9q5fX99fy4s4a1uyu5nB9O59YOcvv8mUIXoZCGRA6kMl0oGKIbdY657qBA2a2i0BIrAvdyDl3L3AvBPoUPKtYZJQkxsVy80XFXL0wl689sZ1vP72TxzdXcOd1Z7IgCjpWD9W1cccT23hux7HP+CbGxZCVmkhNSyddPX1MiI/lvOLJ3LyqmPct0ThH4cjLjuY4Ah3NFwPlBP7Qv985ty1km8sJdD7faGZZwEZgkXOubrjPVUezRKJntlXy749upb61i0+snMVtF8+OyBvhOrp7uXvNPu5+aR9xMca/Xjybqxbm0t7dS3tXL21dvbR0dlPf2k1dSyf1rV3UtnSRkRzPqrnZnF0wKSLbHQ1872h2zvWY2S3AMwQuSb3fObfNzO4ASpxzjwXXXWpm24Fe4LPHCgSRSHXZ/KmcUziZ/3xyOz9bs4+nt1XyzWvPYPmsyJj4p6e3j8e3VPD9v+zmcH07Vy3M5YtXns7UdN2jEW00dLbIGHtlTw2ff/gtyhvbuXz+VD5/xWkUZKW8Y7v+300/R2zt7u3jkY3l/OzFvZTWtXHa1Il86ap5GiwwAmk+BZEw1t7Vy32v7Oful/bR3dvHh84p4NZ3FVPT0slre2t5fX8da/fXMyklga9cPZ8LPbi0tbqpg/UHGygJPg7Xt5GVmsCUtCSmpCWRmRzP/26tpKyhnfm5adx28WzeffoUzVoXoRQKIhGguqmD7/9lN38oOYwD+n8d8zImsKJoMhsONrC/tpW/O3MaX3rPvFEZUqO5o5sP/eJNNh0OXOSXGBfDwukZFOWkUN/aRWVTJ9VNHVQ3d3JGXjq3XVzMRXNzNMdEhFMoiESQHUea+NOGMmbnTGRF0eSB6/c7e3q556X9/OTFvSTGxnD7ZXP54DkziT2Fb+uf/p9N/HlTObdfNpcVsyYzPzd9yLuvnXMKgiiiUBCJIqW1rfzHn7fyyp5a5kxJ5QtXnM6quSc+Q9yfN5Xzrw9u4lOXzOZTl8zxqFoJRyMNhei6k0YkShVkpfDrjy7j7g+cRVdPHx/51To++Is32Fp+dMSfcbi+jX9/ZCtLZ2Zyi8YbkmEoFEQihJlxxRnTePbTF/Llq+axvaKJq37yKrf/cTP1rV3HfG9Pbx+f+p9NANz1j4s0f4EMS/8yRCJMQlwMHzmvkJc+dxGfWFnEnzeVc8n3X+KRjWXDTh36kxf3sv5gA/957QKNNyTHpFAQiVBpSfF8/orTeOLWC5g5OZlP/89mbvzlOg7Xt9Hb5zhU18YLO6v4yQt7+NHze7h2cR6rF2lCGzk2dTSLRIHePsdv1x7kzqd30t3nMKCzp29g/YK8NH7/8XOYqEmAxi3fh7kQkbETG2PceG4B7543hXte2kdifCxF2SkU56QyKyuVzJQEv0uUCKFQEIkiuRkT+OrqBX6XIRFMfQoiIjJAoSAiIgMUCiIiMkChICIiAxQKIiIyQKEgIiIDFAoiIjJAoSAiIgMibpgLM6sBDoYsSgeGGj948PJjvR7ueRZQe4olD1ffiWw31LqRLFMbw6d9w60/3rKRtDdc2jhefxeHWh6ObZzpnDv+vK7OuYh+APeOZPmxXh/jeYlX9Z3IdkOtG8kytTF82neybRxJe8OljeP1dzES23isRzScPnp8hMuP9Xq456NhpJ9Tm26bAAAG7ElEQVR3rO2GWjeSZWrj6BiN9g23/njLRtreU+XVz3Co5dH273So5eHexmFF3OmjsWRmJW4EowpGsmhvY7S3D9TGaBEubYyGIwUv3et3AWMg2tsY7e0DtTFahEUbdaQgIiIDdKQgIiIDxk0omNn9ZlZtZltP4r1LzOwtM9trZj8yMwtZd6uZ7TKzbWZ25+hWfUI1jnr7zOwrZlZuZpuCjytHv/ITqtOTn2Fw/e1m5swsa/QqPnEe/Ry/ZmZbgj/DZ80sd/QrP6E6vWjjd8xsZ7Cdj5hZxuhXPuIavWjf+4J/Y/rMzNt+h1O9fClSHsBK4Cxg60m8901gBWDA/wJXBJdfBDwHJAZf50RZ+74C3O73z87LNgbXzQCeIXD/S1a0tRFIC9nmNuDnUdjGS4G44PNvA9+OsvadDswF1gBLvax/3BwpOOdeBupDl5lZkZk9bWbrzewVMztt8PvMbBqBX6rXXeCn82vgmuDq/wN8yznXGdxHtbetGJ5H7QsrHrbxLuBzgO8dbF600TnXFLJpCj6306M2Puuc6wluuhaY7m0rhudR+3Y453aNRf3jJhSGcS9wq3NuCXA78LMhtskDykJelwWXAcwBLjCzN8zsJTM729NqT9yptg/gluAh+f1mluldqSftlNpoZlcD5c65zV4XegpO+edoZl83s8PAB4AveVjryRqNf6v9PkrgW3Y4Gc32eWrcztFsZqnAucAfQ04vJw616RDL+r9pxQGZwDnA2cAfzGxWMOV9NUrtuxv4WvD114DvEfiFCwun2kYzSwa+SODUQ1gapZ8jzrkvAl80sy8AtwBfHuVST9potTH4WV8EeoAHRrPGUzGa7RsL4zYUCBwlNTrnFoUuNLNYYH3w5WME/jCGHopOByqCz8uAPwVD4E0z6yMwfkmNl4WP0Cm3zzlXFfK+/wKe8LLgk3CqbSwCCoHNwV/W6cAGM1vmnKv0uPaRGo1/p6F+BzxJGIUCo9RGM7sReA9wcTh8MQsx2j9Db/nVGePHAyggpPMHeA14X/C5AQuHed86AkcD/Z0/VwaXfxK4I/h8DnCY4L0fUdK+aSHbfBp4MNp+hoO2KcXnjmaPfo6zQ7a5FXgoCtt4ObAdyPa7bV60L2T9GjzuaPb9f94Y/pB+DxwBugl8w/8YgW+JTwObg/+gvjTMe5cCW4F9wE/6//ADCcBvg+s2AO+Ksvb9BngL2ELgm8y0sWrPWLVx0Da+h4JHP8eHg8u3EBgbJy8K27iXwJeyTcGHb1dYedS+a4Of1QlUAc94Vb/uaBYRkQHj/eojEREJoVAQEZEBCgURERmgUBARkQEKBRERGaBQkKhgZi1jvL/7zGzeKH1Wb3AE061m9vjxRvg0swwz+5fR2LfIYLokVaKCmbU451JH8fPi3NsDrHkqtHYz+29gt3Pu68fYvgB4wjm3YCzqk/FFRwoStcws28weNrN1wcd5weXLzOw1M9sY/O/c4PIPm9kfzexx4FkzW2Vma8zsoeBY/Q+EjG+/pn9cezNrCQ44t9nM1prZlODyouDrdWZ2xwiPZl7n7cH6Us3seTPbYIEx9lcHt/kWUBQ8uvhOcNvPBvezxcy+Oor/G2WcUShINPshcJdz7mzgvcB9weU7gZXOucUERgz9Rsh7VgA3OufeFXy9GPgUMA+YBZw3xH5SgLXOuYXAy8DHQ/b/w+D+jzuGTXAsnIsJ3D0O0AFc65w7i8DcHd8LhtLngX3OuUXOuc+a2aXAbGAZsAhYYmYrj7c/kaGM5wHxJPpdAswLGZkyzcwmAunAf5vZbAKjUMaHvOcvzrnQsfDfdM6VAZjZJgJj2rw6aD9dvD1Y4Hrg3cHnK3h73obfAd8dps4JIZ+9HvhLcLkB3wj+ge8jcAQxZYj3Xxp8bAy+TiUQEi8Psz+RYSkUJJrFACucc+2hC83sx8CLzrlrg+fn14Ssbh30GZ0hz3sZ+nem273dOTfcNsfS7pxbZGbpBMLlZuBHBOY+yAaWOOe6zawUSBri/QZ80zl3zwnuV+QddPpIotmzBOYOAMDM+ocuTgfKg88/7OH+1xI4bQVw/fE2ds4dJTBd5u1mFk+gzupgIFwEzAxu2gxMDHnrM8BHg+P2Y2Z5ZpYzSm2QcUahINEi2czKQh6fIfAHdmmw83U7gaHOAe4EvmlmfwViPazpU8BnzOxNYBpw9HhvcM5tJDCS5vUEJopZamYlBI4adga3qQP+GryE9TvOuWcJnJ563czeAh7ib0NDZMR0SaqIR4Izu7U755yZXQ/c4Jxbfbz3ifhJfQoi3lkC/CR4xVAjYTSVqchwdKQgIiID1KcgIiIDFAoiIjJAoSAiIgMUCiIiMkChICIiAxQKIiIy4P8DCyC53Zy79pIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learner.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.395028</td>\n",
       "      <td>0.267539</td>\n",
       "      <td>02:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.099129</td>\n",
       "      <td>0.047389</td>\n",
       "      <td>02:18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.037288</td>\n",
       "      <td>0.025490</td>\n",
       "      <td>02:17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.023574</td>\n",
       "      <td>0.016976</td>\n",
       "      <td>02:18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.025166</td>\n",
       "      <td>0.009883</td>\n",
       "      <td>02:27</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.016545</td>\n",
       "      <td>0.006756</td>\n",
       "      <td>02:19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.009579</td>\n",
       "      <td>0.005723</td>\n",
       "      <td>02:25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.006136</td>\n",
       "      <td>0.004856</td>\n",
       "      <td>02:27</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.006793</td>\n",
       "      <td>0.004263</td>\n",
       "      <td>02:33</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.006496</td>\n",
       "      <td>0.004238</td>\n",
       "      <td>02:40</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learner.fit_one_cycle(config.epochs, max_lr=config.max_lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_preds_as_nparray(ds_type) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    the get_preds method does not yield the elements in order by default\n",
    "    we borrow the code from the RNNLearner to resort the elements into their correct order\n",
    "    \"\"\"\n",
    "    preds = learner.get_preds(ds_type)[0].detach().cpu().numpy()\n",
    "    sampler = [i for i in databunch.dl(ds_type).sampler]\n",
    "    reverse_sampler = np.argsort(sampler)\n",
    "    return preds[reverse_sampler, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_preds = get_preds_as_nparray(DatasetType.Valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_valid = np.argmax(val[label_cols].values, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9972469847928683"
      ]
     },
     "execution_count": 114,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "accuracy_score(y_valid, np.argmax(val_preds, axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_preds = get_preds_as_nparray(DatasetType.Test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"model_bert_base_prob.npy\", test_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_outcome = np.argmax(test_preds, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1    1183\n",
      "2     822\n",
      "0     422\n",
      "3     321\n",
      "dtype: int64\n"
     ]
    }
   ],
   "source": [
    "sample_submission = pd.read_excel(DATA_ROOT / \"Sample_submission.xlsx\", index_col=None)\n",
    "sample_submission['SECTION'] = test_outcome\n",
    "print(pd.Series(test_outcome).value_counts())\n",
    "sample_submission.to_excel(\"model_bert_base.xlsx\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
